from torch import nn

from otherutils.SFPOOL import *
from .colorPrint import *


# tools

def convert_maxpool2d_to_softpool2d(model, pool_factory=None):
    '''
    Recursively replace every ``nn.MaxPool2d`` submodule of ``model``, in place.

    By default each max-pool layer is swapped for ``SoftPooling2D()`` built with
    its default arguments; the original layer's kernel_size/stride/padding are
    NOT carried over. To control the replacement (e.g. choose its kernel size),
    pass ``pool_factory``. The original author's note claims SoftPool replacement
    can improve model performance — this is not verified here.

    Only children are inspected: if ``model`` is itself an ``nn.MaxPool2d`` it is
    not replaced, since there is no parent to assign a new module to. Newly
    inserted modules are not descended into.

    :param model: the ``nn.Module`` whose MaxPool2d submodules are replaced.
    :param pool_factory: optional callable ``f(maxpool) -> nn.Module``; it receives
        the ``nn.MaxPool2d`` being replaced and returns its substitute. When
        ``None`` (default), ``SoftPooling2D()`` is used, matching the previous
        behavior.
    :return: ``model`` itself (modified in place), to allow call chaining.
    '''
    # Snapshot the children first so the setattr below never modifies the
    # module registry while it is being iterated.
    for child_name, child in list(model.named_children()):
        if isinstance(child, nn.MaxPool2d):
            replacement = (
                pool_factory(child) if pool_factory is not None else SoftPooling2D()
            )
            setattr(model, child_name, replacement)
        else:
            convert_maxpool2d_to_softpool2d(child, pool_factory)
    return model

# Example usage (VGGNet16 must be imported from the model definitions):
# model = VGGNet16(1000)
# convert_maxpool2d_to_softpool2d(model)
# print(model)
